{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DeepSeek GAOKAO Fine-Tuning in Practice"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Environment Setup and Downloading DeepSeek-R1-Distill-Llama-8B"
   ]
  },
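  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# python -m venv ds\n",
    "# source ds/bin/activate   # activate the virtual environment (Linux/macOS)\n",
    "# pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118   # or the wheel index that matches your CUDA version\n",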
    "# pip install unsloth\n",
    "# pip install wandb\n",
    "# pip install modelscope\n",
    "# mkdir DeepSeek-R1-Distill-Llama-8B\n",
    "# modelscope download --model deepseek-ai/DeepSeek-R1-Distill-Llama-8B --local_dir ./DeepSeek-R1-Distill-Llama-8B"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Loading the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.6.0+cu124\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "print(torch.__version__)\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n",
      "🦥 Unsloth Zoo will now patch everything to make training faster!\n"
     ]
    }
   ],
   "source": [
    "from unsloth import FastLanguageModel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DeepSeek-R1-Distill-Llama-8B 推理时至少需要约16GB RAM + 8GB 显存\n",
    "\n",
    "命令行nvidia-smi 输入后查看memory-usage 可检查自己显存剩余 "
   ]
  },
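  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model configuration\n",
    "max_seq_length = 2048  # maximum sequence length in tokens\n",
    "dtype = None  # None lets Unsloth choose the dtype automatically\n",
    "load_in_4bit = False  # set to True to load 4-bit quantized weights and save GPU memory\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n"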
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==((====))==  Unsloth 2025.2.9: Fast Llama patching. Transformers: 4.48.3.\n",
      "   \\\\   /|    GPU: NVIDIA GeForce RTX 4090 D. Max memory: 23.643 GB. Platform: Linux.\n",
      "O^O/ \\_/ \\    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0\n",
      "\\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]\n",
      " \"-____-\"     Free Apache license: http://github.com/unslothai/unsloth\n",
      "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aeecffa1d12742c284b5b23609a2445f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./DeepSeek-R1-Distill-Llama-8B does not have a padding token! Will use pad_token = <|finetune_right_pad_id|>.\n"
     ]
    }
   ],
   "source": [
    "# Out of the box, DeepSeek-R1-Distill-Llama-8B works better in English\n",
    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name = \"./DeepSeek-R1-Distill-Llama-8B\",\n",
    "    max_seq_length = max_seq_length,\n",
    "    dtype = dtype,\n",
    "    load_in_4bit = load_in_4bit,\n",
    "    device_map={\"\": device},  # load all parameters onto the chosen device\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(tokenizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "调整模型为推理模式"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# FastLanguageModel.for_inference(model) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试问答推理功能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Question\n",
    "# question = \"2. (5 分) 已知复数 $z=\\\\frac{\\\\sqrt{3}+i}{(1-\\\\sqrt{3} i)^{2}}, \\\\bar{z}$ 是 $z$ 的共轭复数, 则 $z\\\\cdot\\\\bar{z}=(\\\\quad)$\\nA. $\\\\frac{1}{4}$\\nB. $\\\\frac{1}{2}$\\nC. 1\\nD. 2\\n\"\n",
    "# # Use the tokenizer to convert the question into token IDs:\n",
    "# inputs = tokenizer([question], return_tensors=\"pt\").to(device)\n",
    "# print(inputs)\n",
    "# # Run the model to generate an answer\n",
    "# outputs = model.generate(\n",
    "#     input_ids=inputs.input_ids,\n",
    "#     attention_mask=inputs.attention_mask,\n",
    "#     max_new_tokens=1200,\n",
    "#     use_cache=True,\n",
    "# )\n",
    "# # The generated answer is also a sequence of token IDs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Decode the generated token IDs back into text\n",
    "# response = tokenizer.batch_decode(outputs)\n",
    "# print(response[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Preparing the GAOKAO Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from datasets import Dataset\n",
    "# import wandb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the prompt template used for training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_prompt_style = \"\"\"\n",
    "    ### 提示:\n",
    "    你是一个对于解答高考题目有丰富经验的专家，现在有人问你关于{}。\n",
    "    请回答下面的问题，在回答问题之前请给出逐步的推理过程。\n",
    "\n",
    "    ### 问题：\n",
    "    {}\n",
    "\n",
    "    ### 回答：\n",
    "    <think>\n",
    "    {}\n",
    "    </think>\n",
    "    <answer>\n",
    "    {}\n",
    "    </answer>\n",
    "    \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def formatting_prompts_func(data):\n",
    "\n",
    "    EOS_TOKEN = tokenizer.eos_token\n",
    "    keywords = data[\"keywords\"]\n",
    "    inputs = data[\"Question\"]\n",
    "    cots = data[\"Complex_CoT\"]\n",
    "    outputs = data[\"Response\"]\n",
    "    texts = []\n",
    "    for k,i,c,o in zip(keywords, inputs, cots, outputs):\n",
    "        text = train_prompt_style.format(k, i, c, o) + EOS_TOKEN\n",
    "        texts.append(text)\n",
    "    return {\n",
    "        \"text\": texts,\n",
    "        }"
   ]
  },
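  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is a minimal, self-contained sketch of what a batched formatting function receives and returns when it is passed to `Dataset.map(..., batched=True)`: a dict that maps each column name to a list of values goes in, and a dict holding the new `text` column comes out. The names `demo_template`, `DEMO_EOS`, `demo_format` and `demo_batch` are hypothetical stand-ins invented for illustration; the notebook itself uses `train_prompt_style` and `tokenizer.eos_token`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical stand-ins so that this cell runs without the model or tokenizer\n",
    "demo_template = \"Topic: {}\\nQ: {}\\n<think>{}</think>\\n<answer>{}</answer>\"\n",
    "DEMO_EOS = \"<eos>\"  # placeholder for tokenizer.eos_token\n",
    "\n",
    "def demo_format(batch):\n",
    "    # With batched=True, `batch` maps each column name to a list of values\n",
    "    columns = zip(batch[\"keywords\"], batch[\"Question\"], batch[\"Complex_CoT\"], batch[\"Response\"])\n",
    "    texts = [demo_template.format(k, q, c, a) + DEMO_EOS for k, q, c, a in columns]\n",
    "    return {\"text\": texts}\n",
    "\n",
    "demo_batch = {\n",
    "    \"keywords\": [\"Math\", \"Math\"],\n",
    "    \"Question\": [\"1+1=?\", \"2+2=?\"],\n",
    "    \"Complex_CoT\": [\"add the numbers\", \"add the numbers\"],\n",
    "    \"Response\": [\"2\", \"4\"],\n",
    "}\n",
    "demo_out = demo_format(demo_batch)\n",
    "print(demo_out[\"text\"][0])"
   ]
  },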
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_GAOKAO_data(root_folder):\n",
    "    # Read every JSON file under root_folder (recursively).\n",
    "    # Return the raw training records, each with four fields:\n",
    "    # 'keywords', 'Question', 'Complex_CoT', 'Response'.\n",
    "\n",
    "    train_data = [] \n",
    "    for foldername, _ , filenames in os.walk(root_folder):\n",
    "        for filename in filenames:\n",
    "            if filename.endswith('.json'):  # only process JSON files\n",
    "                file_path = os.path.join(foldername, filename)  # full path to the file\n",
    "                with open(file_path, 'r', encoding='utf-8') as file:\n",
    "                    file_content = file.read()  # read the file content\n",
    "                    data_dict = json.loads(file_content)  # parse the JSON content\n",
    "                    k = data_dict[\"keywords\"]\n",
    "                    examples = data_dict[\"example\"]\n",
    "                    for example in examples:\n",
    "                        q = example[\"question\"]\n",
    "                        \n",
    "                        # ans is a list; join it into a single string\n",
    "                        ans = example[\"answer\"]\n",
    "                        ans = \", \".join(ans) \n",
    "\n",
    "                        cot = example[\"analysis\"]\n",
    "                        tmp_dict = {\"keywords\": k,\n",
    "                                    \"Question\": q,\n",
    "                                    \"Complex_CoT\": cot,\n",
    "                                    \"Response\": ans\n",
    "                                    } \n",
    "                        train_data.append(tmp_dict)  # add the record to the list\n",
    "\n",
    "    return train_data"
   ]
  },
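  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the file layout that `read_GAOKAO_data` relies on, the cell below builds one in-memory record with the same keys the function reads (`keywords`, `example`, `question`, `answer`, `analysis`), flattens it into rows, and then turns the list of rows into the dict of columns that `Dataset.from_dict()` expects. The sample values and the names `sample_file`, `rows` and `columns` are made up for this sketch; real files will contain more fields and longer text."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Made-up content shaped like one GAOKAO JSON file (illustrative values only)\n",
    "sample_file = {\n",
    "    \"keywords\": \"demo_MCQs\",\n",
    "    \"example\": [\n",
    "        {\"question\": \"Q1\", \"answer\": [\"A\"], \"analysis\": \"why A\"},\n",
    "        {\"question\": \"Q2\", \"answer\": [\"B\", \"C\"], \"analysis\": \"why B and C\"},\n",
    "    ],\n",
    "}\n",
    "\n",
    "# One row per example, using the same field names as the notebook\n",
    "rows = [\n",
    "    {\n",
    "        \"keywords\": sample_file[\"keywords\"],\n",
    "        \"Question\": ex[\"question\"],\n",
    "        \"Complex_CoT\": ex[\"analysis\"],\n",
    "        \"Response\": \", \".join(ex[\"answer\"]),  # the answer list becomes one string\n",
    "    }\n",
    "    for ex in sample_file[\"example\"]\n",
    "]\n",
    "\n",
    "# List of rows -> dict of columns, the input format of Dataset.from_dict()\n",
    "columns = {key: [row[key] for row in rows] for key in rows[0]}\n",
    "print(columns[\"Response\"])"
   ]
  },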
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "【解答】 答案 A． was/were  doing，表示过去的某个时间点或时间段正在做某事\n",
      "，根据句意，我没有读完简爱，我昨天一天一直在写家庭作业． 故选 A． \n",
      "【点评】\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36ee41a4ab8a40c29057e8b445d571a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/2811 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Quick functional test\n",
    "\n",
    "root_folder = \"GAOKAO\"  # replace with the path to your dataset root folder\n",
    "train_data = read_GAOKAO_data(root_folder)\n",
    "print(train_data[0][\"Complex_CoT\"])\n",
    "\n",
    "# Convert the plain list into a Hugging Face Dataset so its data-processing methods can be used\n",
    "data_dict = {key: [item[key] for item in train_data] for key in train_data[0].keys()}\n",
    "# Create the Dataset object with Dataset.from_dict()\n",
    "train_data = Dataset.from_dict(data_dict)\n",
    "\n",
    "train_data = train_data.map(formatting_prompts_func, batched = True,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "    ### 提示:\n",
      "    你是一个对于解答高考题目有丰富经验的专家，现在有人问你关于2010-2022_Political_Science_MCQs。\n",
      "    请回答下面的问题，在回答问题之前请给出逐步的推理过程。\n",
      "\n",
      "    ### 问题：\n",
      "    1．（ 3分）按照中国一东盟自由贸易协议， 成员国 90%的贸易商品实行零关税 。\n",
      "如果以前一件 10人民币元的 M商品出口到某东盟成员国 N国的关税为 5%，\n",
      "本外币间的汇率为 l：8.2010年该商品实行零关税， 中国生产 M商品的劳动\n",
      "生产率提高 25%，其他条件不变 ，则一件 M商品在实行零关税之前和之后出\n",
      "口到 N国的价格用 N国货币单位表示分别为（ 　　） \n",
      "A．80，84 B．84，80 C．84.64  D．84，100\n",
      "\n",
      "\n",
      "    ### 回答：\n",
      "    <think>\n",
      "    C正确，实行零关税前， 因为汇率为 1：8，关税为 5%，所以 M商品用\n",
      "N国货币表示价格为（ 10×8）×（1+5%）=84．实行零关税后，因为劳动生\n",
      "产率（社会劳动生产率 ）提高 25%，且零关税 ，所以价格为 （10/1.25）×8=64\n",
      "．故答案为 C； \n",
      "ABD均不正确，故排除。  \n",
      "故选： C。\n",
      "\n",
      "    </think>\n",
      "    <answer>\n",
      "    C\n",
      "    </answer>\n",
      "    <｜end▁of▁sentence｜>\n"
     ]
    }
   ],
   "source": [
    "print(train_data[0][\"text\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap data loading and preprocessing in a single function\n",
    "def get_train_data(root_folder):\n",
    "    # root_folder is the path to the dataset root folder\n",
    "    train_data = read_GAOKAO_data(root_folder)\n",
    "    # Convert to a Hugging Face Dataset to use its built-in data-processing methods\n",
    "    \n",
    "    # Turn the list of records into a dict of columns\n",
    "    data_dict = {key: [item[key] for item in train_data] for key in train_data[0].keys()}\n",
    "    # Create the Dataset object with Dataset.from_dict()\n",
    "    train_data = Dataset.from_dict(data_dict)\n",
    "\n",
    "    train_data = train_data.map(formatting_prompts_func, batched = True,)\n",
    "    return train_data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d5fe5db371445ac851c8c9fd2f6b230",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/3145 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train_data = get_train_data(\"GAOKAO\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Fine-Tuning the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from trl import SFTTrainer\n",
    "from transformers import TrainingArguments\n",
    "from unsloth import is_bfloat16_supported"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Put the model into fine-tuning mode by attaching LoRA adapters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Unsloth 2025.2.9 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.\n"
     ]
    }
   ],
   "source": [
    "model = FastLanguageModel.get_peft_model(\n",
    "    model,\n",
    "    r=16,  \n",
    "    target_modules=[\n",
    "        \"q_proj\",\n",
    "        \"k_proj\",\n",
    "        \"v_proj\",\n",
    "        \"o_proj\",\n",
    "        \"gate_proj\",\n",
    "        \"up_proj\",\n",
    "        \"down_proj\",\n",
    "    ],\n",
    "    lora_alpha=16,\n",
    "    lora_dropout=0,  \n",
    "    bias=\"none\",  \n",
    "    use_gradient_checkpointing=\"unsloth\",  # True or \"unsloth\" for very long context\n",
    "    random_state=1290,\n",
    "    use_rslora=False,  \n",
    "    loftq_config=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the trainer and its hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bea9275fc7464c348f0491acafd2819f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map (num_proc=2):   0%|          | 0/3145 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer = SFTTrainer(\n",
    "    model=model,\n",
    "    tokenizer=tokenizer,\n",
    "    train_dataset=train_data,\n",
    "    dataset_text_field=\"text\",\n",
    "    max_seq_length=max_seq_length,\n",
    "    dataset_num_proc=2,\n",
    "    args=TrainingArguments(\n",
    "        per_device_train_batch_size=2,\n",
    "        num_train_epochs=3,\n",
    "        gradient_accumulation_steps=4,\n",
    "        # Use num_train_epochs = 1, warmup_ratio for full training runs!\n",
    "        warmup_steps=4,\n",
    "        learning_rate=2e-4,\n",
    "        fp16=not is_bfloat16_supported(),\n",
    "        bf16=is_bfloat16_supported(),\n",
    "        logging_steps=10,\n",
    "        optim=\"adamw_8bit\",\n",
    "        weight_decay=0.01,\n",
    "        lr_scheduler_type=\"linear\",\n",
    "        seed=1291,\n",
    "        output_dir=\"outputs\",\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Log the training loss to Weights & Biases (wandb)"
   ]
  },
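  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "\n",
    "# Never hard-code an API key in a notebook. Called without a key, wandb.login()\n",
    "# uses the WANDB_API_KEY environment variable or prompts for the key.\n",
    "wandb.login()"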
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1\n",
      "   \\\\   /|    Num examples = 3,145 | Num Epochs = 3\n",
      "O^O/ \\_/ \\    Batch size per device = 2 | Gradient Accumulation steps = 4\n",
      "\\        /    Total batch size = 8 | Total steps = 1,179\n",
      " \"-____-\"     Number of trainable parameters = 41,943,040\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1179' max='1179' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1179/1179 1:18:29, Epoch 2/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>2.004200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>20</td>\n",
       "      <td>1.476900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>30</td>\n",
       "      <td>1.303000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>40</td>\n",
       "      <td>1.383200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>50</td>\n",
       "      <td>1.279000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>60</td>\n",
       "      <td>1.356900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>70</td>\n",
       "      <td>1.212300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>80</td>\n",
       "      <td>1.265000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>90</td>\n",
       "      <td>1.376800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>1.109200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>110</td>\n",
       "      <td>1.215200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>120</td>\n",
       "      <td>1.191500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>130</td>\n",
       "      <td>1.112000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>140</td>\n",
       "      <td>1.132500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>150</td>\n",
       "      <td>1.237600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>160</td>\n",
       "      <td>1.123500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>170</td>\n",
       "      <td>1.033100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>180</td>\n",
       "      <td>1.199000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>190</td>\n",
       "      <td>1.065400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>1.139500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>210</td>\n",
       "      <td>1.204100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>220</td>\n",
       "      <td>1.090700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>230</td>\n",
       "      <td>1.130500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>240</td>\n",
       "      <td>1.033800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>250</td>\n",
       "      <td>1.253200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>260</td>\n",
       "      <td>1.051500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>270</td>\n",
       "      <td>1.011200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>280</td>\n",
       "      <td>1.114300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>290</td>\n",
       "      <td>1.160100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>1.116400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>310</td>\n",
       "      <td>1.152500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>320</td>\n",
       "      <td>1.154100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>330</td>\n",
       "      <td>0.970500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>340</td>\n",
       "      <td>1.056200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>350</td>\n",
       "      <td>1.112600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>360</td>\n",
       "      <td>0.976900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>370</td>\n",
       "      <td>1.084700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>380</td>\n",
       "      <td>1.114500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>390</td>\n",
       "      <td>1.062100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.851000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>410</td>\n",
       "      <td>0.930100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>420</td>\n",
       "      <td>1.041000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>430</td>\n",
       "      <td>0.880800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>440</td>\n",
       "      <td>0.891900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>450</td>\n",
       "      <td>1.010200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>460</td>\n",
       "      <td>0.924500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>470</td>\n",
       "      <td>1.035200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>480</td>\n",
       "      <td>0.886000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>490</td>\n",
       "      <td>1.057600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.848200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>510</td>\n",
       "      <td>1.020600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>520</td>\n",
       "      <td>0.875900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>530</td>\n",
       "      <td>0.843200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>540</td>\n",
       "      <td>1.018300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>550</td>\n",
       "      <td>0.920200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>560</td>\n",
       "      <td>0.988900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>570</td>\n",
       "      <td>1.101500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>580</td>\n",
       "      <td>0.994900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>590</td>\n",
       "      <td>0.923800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.893200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>610</td>\n",
       "      <td>0.934100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>620</td>\n",
       "      <td>0.986000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>630</td>\n",
       "      <td>0.928200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>640</td>\n",
       "      <td>1.032500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>650</td>\n",
       "      <td>0.978100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>660</td>\n",
       "      <td>0.896300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>670</td>\n",
       "      <td>0.746000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>680</td>\n",
       "      <td>0.803900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>690</td>\n",
       "      <td>1.013300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.910200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>710</td>\n",
       "      <td>0.955400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>720</td>\n",
       "      <td>0.928600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>730</td>\n",
       "      <td>0.923400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>740</td>\n",
       "      <td>0.908500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>750</td>\n",
       "      <td>0.889300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>760</td>\n",
       "      <td>0.934300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>770</td>\n",
       "      <td>0.955300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>780</td>\n",
       "      <td>0.873600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>790</td>\n",
       "      <td>0.945100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.758200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>810</td>\n",
       "      <td>0.836700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>820</td>\n",
       "      <td>0.766700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>830</td>\n",
       "      <td>0.761000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>840</td>\n",
       "      <td>0.845200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>850</td>\n",
       "      <td>0.745500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>860</td>\n",
       "      <td>0.691600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>870</td>\n",
       "      <td>0.633100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>880</td>\n",
       "      <td>0.743600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>890</td>\n",
       "      <td>0.743900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.718700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>910</td>\n",
       "      <td>0.736700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>920</td>\n",
       "      <td>0.690200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>930</td>\n",
       "      <td>0.762800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>940</td>\n",
       "      <td>0.695400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>950</td>\n",
       "      <td>0.874800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>960</td>\n",
       "      <td>0.726000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>970</td>\n",
       "      <td>0.758100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>980</td>\n",
       "      <td>0.662100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>990</td>\n",
       "      <td>0.765300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.751100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1010</td>\n",
       "      <td>0.715000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1020</td>\n",
       "      <td>0.813400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1030</td>\n",
       "      <td>0.749200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1040</td>\n",
       "      <td>0.661200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1050</td>\n",
       "      <td>0.742400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1060</td>\n",
       "      <td>0.648200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1070</td>\n",
       "      <td>0.767100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1080</td>\n",
       "      <td>0.751900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1090</td>\n",
       "      <td>0.863400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.821900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1110</td>\n",
       "      <td>0.799200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1120</td>\n",
       "      <td>0.777300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1130</td>\n",
       "      <td>0.796200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1140</td>\n",
       "      <td>0.658300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1150</td>\n",
       "      <td>0.911900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1160</td>\n",
       "      <td>0.837400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1170</td>\n",
       "      <td>0.676600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer_stats = trainer.train()"
   ]
  },
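  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The step count in the training log can be checked by hand. The cell below is a small arithmetic sketch that reproduces the logged total of 1,179 steps from the values shown there, assuming that the leftover batches at the end of each epoch, which do not fill a whole gradient-accumulation group, add no extra optimizer step (the logged figure is consistent with that assumption). The variable names are chosen for this illustration only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_examples = 3145           # \"Num examples\" in the training log\n",
    "per_device_batch_size = 2     # per_device_train_batch_size\n",
    "grad_accum_steps = 4          # gradient_accumulation_steps\n",
    "num_epochs = 3                # num_train_epochs\n",
    "\n",
    "# Mini-batches per epoch (the last one is partial)\n",
    "batches_per_epoch = math.ceil(num_examples / per_device_batch_size)\n",
    "# Optimizer updates per epoch: one per full group of accumulated batches (assumption)\n",
    "updates_per_epoch = batches_per_epoch // grad_accum_steps\n",
    "total_steps = updates_per_epoch * num_epochs\n",
    "\n",
    "effective_batch_size = per_device_batch_size * grad_accum_steps\n",
    "print(effective_batch_size, batches_per_epoch, updates_per_epoch, total_steps)"
   ]
  },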
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Test the fine-tuned model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt_style = \"\"\"\n",
    "    ### 提示:\n",
    "    你是一个对于解答高考题目有丰富经验的专家，现在有人问你关于{}。\n",
    "    请回答下面的问题，在回答问题之前请给出逐步的推理过程。\n",
    "\n",
    "    ### 问题：\n",
    "    {}\n",
    "\n",
    "    ### 回答：\n",
    "    <think>{}\n",
    "    \"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"3. (5 分) 曲线 $y=\\\\frac{x}{x+2}$ 在点 $(-1,-1)$ 处的切线方程为（ $）$\\nA. $y=2 x+1$\\nB. $y=2 x-1$\\nC. $y=-2 x-3$\\nD. $y=-2 x-2$\\n\"\n",
    "questype = \"2010-2022_Math_I_MCQs\"\n",
    "\n",
    "FastLanguageModel.for_inference(model)  # Unsloth has 2x faster inference!\n",
    "inputs = tokenizer([prompt_style.format(questype ,question, \"\")], return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "outputs = model.generate(\n",
    "    input_ids=inputs.input_ids,\n",
    "    attention_mask=inputs.attention_mask,\n",
    "    max_new_tokens=1200,\n",
    "    use_cache=True,\n",
    ")\n",
    "response = tokenizer.batch_decode(outputs)\n",
    "print(response[0].split(\"### 回答：\")[1])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5. Saving the Model"
   ]
  },
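  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_save = \"Ds_Llama8B_GAOKAO\"\n",
    "model.save_pretrained(model_save)  # for a PEFT model this saves only the LoRA adapter weights\n",
    "tokenizer.save_pretrained(model_save)  # keep the tokenizer files next to the adapter"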
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the fine-tuned model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from peft import PeftModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "==((====))==  Unsloth 2025.2.9: Fast Llama patching. Transformers: 4.48.3.\n",
      "   \\\\   /|    GPU: NVIDIA GeForce RTX 4090 D. Max memory: 23.643 GB. Platform: Linux.\n",
      "O^O/ \\_/ \\    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0\n",
      "\\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]\n",
      " \"-____-\"     Free Apache license: http://github.com/unslothai/unsloth\n",
      "Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0700aa1e85e4f779a0aa1f8f1fba3c4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./autodl-fs/DeepSeek-R1-Distill-Llama-8B does not have a padding token! Will use pad_token = <|finetune_right_pad_id|>.\n"
     ]
    }
   ],
   "source": [
    "#DeepSeek-R1-Distill-Llama-8B 微调后更适用于中国高考问答\n",
    "base_model, tokenizer = FastLanguageModel.from_pretrained(\n",
    "    model_name = \"./autodl-fs/DeepSeek-R1-Distill-Llama-8B\",\n",
    "    max_seq_length = max_seq_length,\n",
    "    dtype = dtype,\n",
    "    load_in_4bit = load_in_4bit,\n",
    "    device_map={\"\": device},  # 将所有参数加载到指定设备\n",
    ")\n",
    "\n",
    "\n",
    "# 加载 LoRA 适配器，加入微调后的变化\n",
    "model = PeftModel.from_pretrained(\n",
    "    base_model,\n",
    "    \"./autodl-fs/outputs/outputs/checkpoint-1179\",\n",
    "    adapter_name=\"lora_adapter\"\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "问答"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "    <think>\n",
      "     解: 曲线 $y=\\frac{x}{x+2}$, 可得 $y^{\\prime}=\\frac{2}{(x+2)^{2}}$, 则曲线在点 $(-1,-1)$ 处的切线斜率为 $\\frac{2}{4}=0.5$,\n",
      "\n",
      "则曲线在点 $(-1,-1)$ 处的切线方程为 $y+1=0.5(x+1)$, 即 $y=0.5 x+0$.\n",
      "\n",
      "故选: A.\n",
      "\n",
      "    </think>\n",
      "    <answer>\n",
      "    A\n",
      "    </answer>\n",
      "    <｜end▁of▁sentence｜>\n"
     ]
    }
   ],
   "source": [
    "question = \"3. (5 分) 曲线 $y=\\\\frac{x}{x+2}$ 在点 $(-1,-1)$ 处的切线方程为（ $）$\\nA. $y=2 x+1$\\nB. $y=2 x-1$\\nC. $y=-2 x-3$\\nD. $y=-2 x-2$\\n\"\n",
    "questype = \"2010-2022_Math_I_MCQs\"\n",
    "\n",
    "FastLanguageModel.for_inference(model)  # Unsloth has 2x faster inference!\n",
    "inputs = tokenizer([prompt_style.format(questype ,question, \"\")], return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "outputs = model.generate(\n",
    "    input_ids=inputs.input_ids,\n",
    "    attention_mask=inputs.attention_mask,\n",
    "    max_new_tokens=1200,\n",
    "    use_cache=True,\n",
    ")\n",
    "response = tokenizer.batch_decode(outputs)\n",
    "print(response[0].split(\"### 回答：\")[1])"
   ]
  },
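  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check on the output above, the tangent line can be worked out by hand. The derivation below is an added check, not part of the model output. It confirms option A, although the slope it gives differs from the value of 0.5 that appears in the generated reasoning.\n",
    "\n",
    "$$y' = \\frac{(x+2) - x}{(x+2)^2} = \\frac{2}{(x+2)^2}, \\qquad y'(-1) = \\frac{2}{(-1+2)^2} = 2$$\n",
    "\n",
    "$$y - (-1) = 2\\left(x - (-1)\\right) \\quad\\Longrightarrow\\quad y = 2x + 1$$"
   ]
  },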
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "    <think>\n",
      "     【解答】解: A、惯性是物体保持原来运动状态的性质, 物体抵抗运动状态变 化的性质是惯性, 故 $\\mathrm{A}$ 正确;\n",
      "\n",
      "B、惯性是物体保持原来运动状态的性质, 物体在没有力的作用下, 可以继续以 原来速度沿原来方向运动, 故 B 错误;\n",
      "\n",
      "C、惯性是物体保持原来运动状态的性质, 故 C 错误;\n",
      "\n",
      "D、惯性是物体保持原来运动状态的性质, 运动物体如果没有受到力的作用, 将 继续以同一速度沿同一直线运动, 故 D 正确。\n",
      "\n",
      "故选: AD。\n",
      "\n",
      "    </think>\n",
      "    <answer>\n",
      "    AD\n",
      "    </answer>\n",
      "    <｜end▁of▁sentence｜>\n"
     ]
    }
   ],
   "source": [
    "question = \"1. (3 分) 伽利略根据小球在斜面上运动的实验和理想实验, 提出了惯性的概 念, 从而奠定了牛顿力学的基础. 早期物理学家关于惯性有下列说法, 其中 正确的是（） A. 物体抵抗运动状态变化的性质是惯性 B. 没有力作用, 物体只能处于静止状态 C. 行星在圆周轨道上保持匀速率运动的性质是惯性 D. 运动物体如果没有受到力的作用, 将继续以同一速度沿同一直线运动\"\n",
    "questype = \"2010-2022_Physics_MCQs\"\n",
    "\n",
    "FastLanguageModel.for_inference(model)  # Unsloth has 2x faster inference!\n",
    "inputs = tokenizer([prompt_style.format(questype ,question, \"\")], return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "outputs = model.generate(\n",
    "    input_ids=inputs.input_ids,\n",
    "    attention_mask=inputs.attention_mask,\n",
    "    max_new_tokens=1200,\n",
    "    use_cache=True,\n",
    ")\n",
    "response = tokenizer.batch_decode(outputs)\n",
    "print(response[0].split(\"### 回答：\")[1])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
